{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 我的第一个对象化的图像识别模型\n",
    "\n",
    "### 了解keras函数式模型接口\n",
    "为什么叫“函数式模型”，请查看“Keras新手指南”的相关部分\n",
    "Keras的函数式模型为Model，即广义的拥有输入和输出的模型，我们使用Model来初始化一个函数式模型\n",
    "   - `from keras.models import Model`\n",
    "   - `from keras.layers import Input, Dense`\n",
    "   - `a = Input(shape=(32,))`\n",
    "   - `b = Dense(32)(a)`\n",
    "   - `model = Model(inputs=a, outputs=b)`\n",
    "\n",
    "在这里，我们的模型以a为输入，以b为输出，同样我们可以构造拥有多输入和多输出的模型\n",
    "- `model = Model(inputs=[a1, a2], outputs=[b1, b3, b3])`\n",
    "\n",
    "**常用Model属性**\n",
    "- `model.layers：组成模型图的各个层`\n",
    "- `model.inputs：模型的输入张量列表`\n",
    "- `model.outputs：模型的输出张量列表`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting MNIST_data/train-images-idx3-ubyte.gz\n",
      "Extracting MNIST_data/train-labels-idx1-ubyte.gz\n",
      "Extracting MNIST_data/t10k-images-idx3-ubyte.gz\n",
      "Extracting MNIST_data/t10k-labels-idx1-ubyte.gz\n"
     ]
    }
   ],
   "source": [
    "from keras.layers import Input, Dense, TimeDistributed\n",
    "from keras.models import Model\n",
    "from tensorflow.examples.tutorials.mnist import input_data\n",
    "mnist = input_data.read_data_sets(\"MNIST_data/\", one_hot=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ModelMnist():\n",
    "    def __init__(self, datapath):\n",
    "        self.datapath = datapath\n",
    "        self._model = self.CreatModel()\n",
    "        print(\"datapath = \", self.datapath)\n",
    "    \n",
    "    def CreatModel(self):\n",
    "        # This returns a tensor\n",
    "        inputs = Input(shape=(784,), name='main_input')\n",
    "        # a layer instance is callable on a tensor, and returns a tensor\n",
    "        a1 = Dense(64, activation='relu')(inputs)\n",
    "        a2 = Dense(64, activation='relu')(a1)\n",
    "        predictions = Dense(10, activation='softmax', name='main_output')(a2)\n",
    "\n",
    "        # This creates a model that includes\n",
    "        # the Input layer and three Dense layers\n",
    "        model = Model(inputs=inputs, outputs=predictions)\n",
    "        model.compile(optimizer='rmsprop',\n",
    "                      loss='categorical_crossentropy',\n",
    "                      metrics=['accuracy'])\n",
    "        return model\n",
    " \n",
    "    def SaveModel(self, filename='model/', comment='mnist'):\n",
    "        self._model.save_weights(filename + comment + '.model')\n",
    "        f = open('step.txt', 'w')\n",
    "        f.write(filename + comment)\n",
    "        f.close()\n",
    "    \n",
    "    def TrainModel(self):\n",
    "        #mnist = input_data.read_data_sets(self.datapath, one_hot=True)\n",
    "        train_input, train_label = mnist.train.next_batch(1000)\n",
    "        x_test, y_test = mnist.test.next_batch(50)\n",
    "        self._model.fit(train_input,train_label,\n",
    "                  batch_size=16,\n",
    "                  epochs=20)  # starts training\n",
    "        self.SaveModel()\n",
    "    \n",
    "    def LoadModel(self,filename='model/mnist.model'):\n",
    "        self._model.load_weights(filename)\n",
    "        \n",
    "    def TestModel(self,x):\n",
    "        self._model.predict(self, x, batch_size=32, verbose=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "datapath =  MNIST_data/\n",
      "Epoch 1/20\n",
      "1000/1000 [==============================] - 1s 749us/step - loss: 0.5164 - acc: 0.8960\n",
      "Epoch 2/20\n",
      "1000/1000 [==============================] - 0s 175us/step - loss: 0.2960 - acc: 0.9330\n",
      "Epoch 3/20\n",
      "1000/1000 [==============================] - 0s 173us/step - loss: 0.2109 - acc: 0.9500\n",
      "Epoch 4/20\n",
      "1000/1000 [==============================] - 0s 178us/step - loss: 0.1412 - acc: 0.9610 0s - loss: 0.1784 - acc: 0.9\n",
      "Epoch 5/20\n",
      "1000/1000 [==============================] - 0s 187us/step - loss: 0.1156 - acc: 0.9660\n",
      "Epoch 6/20\n",
      "1000/1000 [==============================] - 0s 179us/step - loss: 0.0821 - acc: 0.9780\n",
      "Epoch 7/20\n",
      "1000/1000 [==============================] - 0s 194us/step - loss: 0.0646 - acc: 0.9880\n",
      "Epoch 8/20\n",
      "1000/1000 [==============================] - 0s 185us/step - loss: 0.0440 - acc: 0.9950\n",
      "Epoch 9/20\n",
      "1000/1000 [==============================] - 0s 193us/step - loss: 0.0347 - acc: 0.9950\n",
      "Epoch 10/20\n",
      "1000/1000 [==============================] - 0s 189us/step - loss: 0.0306 - acc: 0.9970\n",
      "Epoch 11/20\n",
      "1000/1000 [==============================] - 0s 176us/step - loss: 0.0226 - acc: 0.9990\n",
      "Epoch 12/20\n",
      "1000/1000 [==============================] - 0s 192us/step - loss: 0.0236 - acc: 0.9980\n",
      "Epoch 13/20\n",
      "1000/1000 [==============================] - 0s 201us/step - loss: 0.0204 - acc: 0.9990\n",
      "Epoch 14/20\n",
      "1000/1000 [==============================] - 0s 180us/step - loss: 0.0179 - acc: 0.9990\n",
      "Epoch 15/20\n",
      "1000/1000 [==============================] - 0s 190us/step - loss: 0.0180 - acc: 0.9990\n",
      "Epoch 16/20\n",
      "1000/1000 [==============================] - 0s 188us/step - loss: 0.0171 - acc: 0.9990\n",
      "Epoch 17/20\n",
      "1000/1000 [==============================] - 0s 183us/step - loss: 0.0180 - acc: 0.9980\n",
      "Epoch 18/20\n",
      "1000/1000 [==============================] - 0s 184us/step - loss: 0.0165 - acc: 0.9990\n",
      "Epoch 19/20\n",
      "1000/1000 [==============================] - 0s 198us/step - loss: 0.0167 - acc: 0.9990\n",
      "Epoch 20/20\n",
      "1000/1000 [==============================] - 0s 178us/step - loss: 0.0163 - acc: 0.9990\n"
     ]
    }
   ],
   "source": [
    "p = ModelMnist('MNIST_data/')\n",
    "p.LoadModel()\n",
    "p.TrainModel()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "img0 = mnist.train.images[1].reshape(28,28)\n",
    "fig = plt.figure(figsize=(10,10))\n",
    "ax0 = fig.add_subplot(221)\n",
    "ax0.imshow(img0)\n",
    "fig.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
